# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import sympy as sm
from sympy.utilities import group
from sympy.printing.str import StrPrinter, StrReprPrinter
from sympy.printing.latex import LatexPrinter
from hysop.tools.htypes import first_not_None, check_instance, to_tuple
# Unicode glyphs used to build pretty symbol names: subscript and superscript
# decimal digits, greek alphabets, subscript signs and parentheses, and a few
# mathematical operator signs.
decimal_subscripts = "₀₁₂₃₄₅₆₇₈₉"  # subscript digits, indexed by digit value
decimal_exponents = "⁰¹²³⁴⁵⁶⁷⁸⁹"  # superscript digits, indexed by digit value
# Lowercase ("greak", sic) and uppercase greek alphabets.
# NOTE(review): the lowercase string also contains the final sigma "ς"
# (25 letters vs 24 uppercase), so greak[i] and Greak[i] only name the same
# letter for i <= 16.
greak = "αβγδεζηθικλμνξοπρςστυφχψω"
Greak = "ΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ"
signs = "₊₋"  # SUBSCRIPT plus and minus (not the superscript forms)
parenthesis = "₍₎"  # SUBSCRIPT left and right parentheses
partial = "∂"  # partial derivative sign
nabla = "∇"  # nabla (del) operator sign
xsymbol = "x"
freq_symbol = greak[12]  # nu ("ν")
[docs]
def round_expr(expr, num_digits=3):
    """
    Return ``expr`` with its non-integer numbers rounded.

    Every ``sympy.Float`` and non-integer ``sympy.Rational`` atom of ``expr``
    is replaced by ``round(number, num_digits)``; integers are left untouched.
    The substitution is done with ``xreplace`` so a new expression is returned.
    """
    candidates = expr.atoms(sm.Float).union(expr.atoms(sm.Rational))
    non_integers = candidates.difference(expr.atoms(sm.Integer))
    rounded = {number: round(number, num_digits) for number in non_integers}
    return expr.xreplace(rounded)
[docs]
def truncate_expr(expr, maxlen=80):
    """
    Return the string form of ``expr`` (via :func:`sstr`) shortened to at most
    ``maxlen`` characters.

    A string that already fits is returned unchanged. Otherwise whole
    space-separated tokens are kept (with their separating spaces) while they
    fit, and "..." is appended; the ellipsis counts toward ``maxlen``. If even
    the first token is too long it is cut mid-token.

    Raises AssertionError if ``maxlen`` < 3 (no room for the ellipsis).
    """
    assert maxlen >= 3
    text = sstr(expr)
    if len(text) <= maxlen:
        return text
    ellipsis = "..."
    budget = maxlen - len(ellipsis)
    tokens = text.split(" ")
    kept = tokens[0]
    for token in tokens[1:]:
        candidate = f"{kept} {token}"
        if len(candidate) > budget:
            break
        kept = candidate
    # The slice only matters when the first token alone exceeds the budget.
    return kept[:budget] + ellipsis
[docs]
class CustomStrPrinter(StrPrinter):
    """sympy ``StrPrinter`` with a compact unicode notation for derivatives."""

    def _print_Derivative(self, expr):
        """
        Print a derivative as ``∂ⁿ<f>/∂<variables>``.

        The variables are taken from ``expr.variables`` in reverse order and
        each consecutive run is written once, followed by a superscript count
        when it is larger than one. The total order ``n`` is only shown when it
        exceeds one. The differentiated expression is wrapped in square
        brackets unless it is a hysop ``Symbol``, ``Dummy`` or ``AppliedUndef``
        (plain sympy symbols are therefore bracketed too). For instance a
        hysop symbol ``f`` differentiated w.r.t. ``(x, x, y)`` prints as
        ``∂³f/∂yx²``.
        """
        wrt = list(reversed(expr.variables))
        order = len(wrt)
        body = self._print(expr.expr)
        if not isinstance(expr.expr, (Symbol, Dummy, AppliedUndef)):
            body = f"[{body}]"
        order_mark = exponent(order) if order > 1 else ""
        pieces = [f"{partial}{order_mark}{body}/{partial}"]
        for variable, count in group(wrt, multiple=False):
            count_mark = exponent(count) if count > 1 else ""
            pieces.append(f"{variable}{count_mark}")
        return "".join(pieces)
[docs]
class CustomStrReprPrinter(StrReprPrinter):
    """hysop hook around sympy's ``StrReprPrinter``; currently adds no behavior."""
[docs]
class CustomLatexPrinter(LatexPrinter):
    """hysop hook around sympy's ``LatexPrinter``; currently adds no behavior."""
[docs]
def sstr(expr, **settings):
    """
    Return ``expr`` as a string rendered by :class:`CustomStrPrinter`.

    Keyword arguments are passed to the printer as its settings dictionary.
    """
    return CustomStrPrinter(settings).doprint(expr)
[docs]
def sstrrepr(expr, **settings):
    """
    Return ``expr`` as a string rendered by :class:`CustomStrReprPrinter`.

    Keyword arguments are passed to the printer as its settings dictionary.
    """
    return CustomStrReprPrinter(settings).doprint(expr)
[docs]
def latex(expr, **settings):
    """
    Return the LaTeX code of ``expr`` rendered by :class:`CustomLatexPrinter`.

    Keyword arguments are passed to the printer as its settings dictionary.
    """
    return CustomLatexPrinter(settings).doprint(expr)
[docs]
def enable_pretty_printing():
    """
    Route ``str()`` of sympy objects through :func:`sstr`.

    This monkey-patches ``sympy.Basic.__str__`` globally, for the whole
    process; classes that define their own ``__str__`` keep using it.
    """
    setattr(sm.Basic, "__str__", sstr)
[docs]
class SymbolicBase:
    """
    Mixin attaching a variable name, a LaTeX name and a pretty name to a
    sympy object.

    It is listed before the sympy base class (e.g.
    ``class Symbol(SymbolicBase, sm.Symbol)``) so that its ``__new__`` and
    printing hooks take precedence:

    * ``str()``, the str printer (``_sympystr``) and the pretty printer
      (``_pretty``) show the pretty name,
    * the LaTeX printer (``_latex``) shows the LaTeX name,
    * the C code printer (``_ccode``) shows the variable name,
    * ``repr()`` shows the raw ``name``.

    Each optional name defaults to ``name`` when not given.
    """

    def __new__(cls, name, var_name=None, latex_name=None, pretty_name=None, **kwds):
        """
        Validate the names, build the object with the next base in the MRO and
        attach the names to it. Extra keyword arguments are forwarded to that
        base ``__new__``.
        """
        check_instance(name, str)
        check_instance(var_name, str, allow_none=True)
        check_instance(latex_name, str, allow_none=True)
        check_instance(pretty_name, str, allow_none=True)
        try:
            obj = super().__new__(cls, name=name, **kwds)
        except TypeError:
            # The next base does not take ``name`` as a keyword: build without it.
            # NOTE(review): this also swallows TypeErrors raised for other
            # reasons by the base constructor — confirm this is intended.
            obj = super().__new__(cls, **kwds)
        # NOTE(review): sympy caches Symbol instances by name and assumptions,
        # so two objects created with the same name may be the same instance,
        # the latest call's names below overwriting the previous ones — verify.
        obj._name = name
        obj._var_name = first_not_None(var_name, name)
        obj._latex_name = first_not_None(latex_name, name)
        obj._pretty_name = first_not_None(pretty_name, name)
        return obj

    def __init__(self, name, var_name=None, latex_name=None, pretty_name=None, **kwds):
        # Intentionally empty: construction happens in __new__. Accepting the
        # constructor arguments here keeps them from reaching the inherited
        # __init__ (e.g. type.__init__ for metaclasses), which may reject them.
        pass

    @property
    def varname(self):
        """Name used for this object in generated code."""
        return self._var_name

    def _sympystr(self, printer):
        return self._pretty_name

    def _latex(self, printer):
        return self._latex_name

    def _ccode(self, printer):
        return self._var_name

    def _pretty(self, printer):
        # sympy's PrettyPrinter combines and renders prettyForm objects
        # (doprint calls ``.render()`` on the result), so returning a plain
        # str here breaks ``sympy.pretty`` on these objects.
        from sympy.printing.pretty.stringpict import prettyForm

        return prettyForm(self._pretty_name)

    def __str__(self):
        return self._pretty_name

    def __repr__(self):
        return self._name
[docs]
class Expr(sm.Expr):
    """Base tag class for hysop symbolic expressions (a ``sympy.Expr`` subclass)."""
[docs]
class UnevaluatedExpr(sm.UnevaluatedExpr):
    """Tag class for hysop unevaluated symbolic expressions (a ``sympy.UnevaluatedExpr`` subclass)."""
[docs]
class UnsplittedExpr(Expr):
    """Tag class, derived from the hysop ``Expr`` tag, for "unsplitted" symbolic expressions."""
[docs]
class Symbol(SymbolicBase, sm.Symbol):
    """hysop symbol: a ``sympy.Symbol`` with separate variable, LaTeX and pretty names."""
[docs]
class Dummy(SymbolicBase, sm.Dummy):
    """hysop dummy variable: a ``sympy.Dummy`` with separate variable, LaTeX and pretty names."""
from sympy.core.function import UndefinedFunction as SympyUndefinedFunction
from sympy.core.function import AppliedUndef as SympyAppliedUndef
[docs]
class UndefinedFunction(SymbolicBase, SympyUndefinedFunction):
    """
    Tag class for hysop (not yet applied) undefined functions.

    Like sympy's ``UndefinedFunction`` this is a metaclass: its instances are
    function classes, which ``SymbolicBase`` equips with variable, LaTeX and
    pretty names.
    """
[docs]
class AppliedUndef(SympyAppliedUndef):
    """
    Tag class for hysop applied undefined functions, e.g. ``f(t)``.

    Every printer renders the application by a stored name only; the call
    arguments are not printed. ``_latex_name``, ``_var_name`` and
    ``_pretty_name`` are looked up through the instance, presumably finding
    the attributes set on the function class by the hysop
    ``UndefinedFunction`` (``SymbolicBase.__new__``) — TODO confirm.
    """

    def _latex(self, printer):
        return self._latex_name

    def _ccode(self, printer):
        return self._var_name

    def _pretty(self, printer):
        # sympy's PrettyPrinter combines and renders prettyForm objects, so a
        # plain str would break ``sympy.pretty`` on expressions containing f(t).
        from sympy.printing.pretty.stringpict import prettyForm

        return prettyForm(self._pretty_name)

    def _sympystr(self, printer):
        return self._pretty_name

    # NOTE: an earlier, disabled variant printed "<pretty_name>(<args>)"
    # from both _pretty and _sympystr.
[docs]
def subscript(i, with_sign=False, disable_unicode=False):
    """
    Render ``str(i)`` with unicode subscript characters.

    Parameters
    ----------
    i:
        Value to render; it is converted with ``str``.
    with_sign: bool
        When set and the text starts with a decimal digit, "+" is prepended.
    disable_unicode: bool
        When set, the (possibly signed) plain text is returned unconverted.

    Digits, "+" and "-" are mapped to their subscript glyphs; every other
    character is kept unchanged.
    """
    text = str(i)
    if with_sign and text[0] in "0123456789":
        text = "+" + text
    if disable_unicode:
        return text
    glyphs = str.maketrans("0123456789+-", decimal_subscripts + signs)
    return text.translate(glyphs)
[docs]
# Superscript glyphs for "0123456789+-", i.e. "⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻".
_EXPONENT_GLYPHS = str.maketrans(
    "0123456789+-",
    "\u2070\u00b9\u00b2\u00b3\u2074\u2075\u2076\u2077\u2078\u2079\u207a\u207b",
)


def exponent(i, with_sign=False, disable_unicode=False):
    """
    Render ``str(i)`` with unicode superscript characters.

    Parameters
    ----------
    i:
        Value to render; it is converted with ``str``.
    with_sign: bool
        When set and the text starts with a decimal digit, "+" is prepended.
    disable_unicode: bool
        When set, the (possibly signed) plain text is returned unconverted
        (same option as :func:`subscript`).

    Digits, "+" and "-" are mapped to their superscript glyphs (the signs
    were previously emitted as subscript glyphs); every other character is
    kept unchanged.
    """
    text = str(i)
    # ``text and`` guards the empty string, which used to raise IndexError.
    if with_sign and text and text[0] in "0123456789":
        text = "+" + text
    if disable_unicode:
        return text
    return text.translate(_EXPONENT_GLYPHS)
[docs]
def subscripts(
    ids, sep, with_sign=False, with_parenthesis=False, prefix="", disable_unicode=False
):
    """
    Render a tuple of ids as subscripts joined by ``sep``.

    Each id goes through :func:`subscript` (``with_sign`` and
    ``disable_unicode`` forwarded). ``prefix`` is prepended and, when
    ``with_parenthesis`` is set, the joined ids are wrapped in parentheses:
    subscript ones, or plain "(" ")" when ``disable_unicode`` is set.
    """
    rendered = sep.join(
        [subscript(i, with_sign, disable_unicode) for i in to_tuple(ids)]
    )
    if not with_parenthesis:
        return f"{prefix}{rendered}"
    if disable_unicode:
        left, right = "(", ")"
    else:
        left, right = parenthesis[0], parenthesis[1]
    return f"{prefix}{left}{rendered}{right}"
[docs]
def exponents(ids, sep, with_sign=False, with_parenthesis=False, prefix=""):
    """
    Render a tuple of ids as superscripts joined by ``sep``.

    Each id goes through :func:`exponent` (``with_sign`` forwarded).
    ``prefix`` is prepended and, when ``with_parenthesis`` is set, the joined
    ids are wrapped in superscript parentheses.
    """
    ids = to_tuple(ids)
    joined = sep.join([exponent(i, with_sign) for i in ids])
    if with_parenthesis:
        # Superscript "⁽" (U+207D) and "⁾" (U+207E); the module-level
        # ``parenthesis`` constant holds the *subscript* glyphs.
        return f"{prefix}\u207d{joined}\u207e"
    return f"{prefix}{joined}"
[docs]
def tensor_symbol(
    prefix,
    shape,
    origin=None,
    mask=None,
    sep=None,
    with_parenthesis=False,
    force_sign=False,
):
    """
    Build an object ``np.ndarray`` of real ``sympy.Symbol`` named by subscripts.

    Parameters
    ----------
    prefix:
        Common name prefix of every symbol.
    shape: tuple of int
        Shape of the generated tensor.
    origin: array-like, optional
        Index mapped to subscript 0 along each axis (zeros by default); the
        subscripts of entry ``idx`` are ``idx - origin``.
    mask: array-like, optional
        Entries where ``mask[idx]`` is falsy hold 0 instead of a symbol.
    sep: str, optional
        Separator between subscripts (defaults to ",").
    with_parenthesis, force_sign:
        Subscript styling forwarded to :func:`subscripts`. Signs are also
        enabled when some ``origin`` component is positive and the tensor has
        more than one dimension.

    Returns
    -------
    (tensor, tensor_vars)
        The array and its flattened content (C order) as a list; masked
        entries appear in that list as 0.
    """
    if origin is None:
        origin = np.asarray([0] * len(shape))
    else:
        origin = np.asarray(origin)
    if sep is None:
        sep = ","
    with_sign = force_sign or ((origin > 0).any() and len(shape) > 1)

    tensor = np.empty(shape=shape, dtype=object)
    for idx in np.ndindex(*shape):
        if mask is not None and not mask[idx]:
            tensor[idx] = 0
            continue
        symbol_name = subscripts(
            idx - origin,
            sep,
            with_sign=with_sign,
            with_parenthesis=with_parenthesis,
            prefix=prefix,
        )
        tensor[idx] = sm.Symbol(symbol_name, real=True)
    return tensor, tensor.ravel().tolist()
[docs]
def tensor_xreplace(tensor, vars):
    """
    Return a copy of ``tensor`` (np.ndarray) with ``vars`` substituted.

    Only ``sympy.Expr`` entries are processed, in this order of preference:
    the entry itself is a key of ``vars``; the entry has a ``name`` attribute
    that is a key of ``vars``; otherwise ``entry.xreplace(vars)`` is stored.
    Other entries are left as they are in the copy.
    """
    result = tensor.copy()
    for idx in np.ndindex(*tensor.shape):
        entry = tensor[idx]
        if not isinstance(entry, sm.Expr):
            continue
        if entry in vars.keys():
            result[idx] = vars[entry]
        elif hasattr(entry, "name") and entry.name in vars.keys():
            result[idx] = vars[entry.name]
        else:
            result[idx] = entry.xreplace(vars)
    return result
[docs]
def _non_eval_xreplace_child(arg, rule):
    """Apply non_eval_xreplace to ``arg``; keep ``arg`` as-is if that raises AttributeError."""
    try:
        return non_eval_xreplace(arg, rule)
    except AttributeError:
        return arg


def non_eval_xreplace(expr, rule):
    """
    Variant of sympy's ``xreplace`` that rebuilds modified nodes unevaluated.

    If ``expr`` is a key of ``rule`` its value is returned directly, without
    descending into it. Otherwise, when ``rule`` is non-empty, the arguments
    of ``expr`` are processed recursively and, if any of them changed
    (compared with ``!=``), ``expr.func(*new_args, evaluate=False)`` is
    returned; else ``expr`` itself is returned. An argument whose processing
    raises ``AttributeError`` (e.g. an object without ``.args``) is kept
    unchanged.
    """
    if expr in rule:
        return rule[expr]
    if not rule:
        return expr
    old_args = expr.args
    new_args = tuple(_non_eval_xreplace_child(arg, rule) for arg in old_args)
    if any(new != old for new, old in zip(new_args, old_args)):
        return expr.func(*new_args, evaluate=False)
    return expr
# Convert powers to multiplications in polynomial expressions.
# Example: x^3 -> x*x*x
[docs]
def remove_pows(expr):
    """
    Rewrite integer powers as explicit unevaluated products: x^3 -> x*x*x.

    A power whose exponent is an integer ``n`` with ``|n| >= 2`` is expanded.
    A negative exponent becomes the reciprocal of the product
    (x^-2 -> 1/(x*x)). Other powers have no finite product form and are left
    untouched: x^-1 (e.g. from a division), sqrt(x), x^n with symbolic n.
    Previously x^-1 was replaced by 1 and non-integer exponents raised
    TypeError. Powers nested in the base of another power are expanded too,
    e.g. (x^2 + 1)^2.

    Products are built with ``evaluate=False`` (and substituted with
    :func:`non_eval_xreplace`) so that sympy does not fold them back into
    powers.
    """
    rule = {}
    for power in expr.atoms(sm.Pow):
        base, exp = power.as_base_exp()
        if not exp.is_Integer or abs(int(exp)) < 2:
            continue
        n = int(exp)
        # The replacement is substituted as a whole and never revisited, so
        # powers hidden inside the base must be expanded beforehand.
        base = remove_pows(base)
        product = sm.Mul(*([base] * abs(n)), evaluate=False)
        rule[power] = product if n > 0 else sm.Pow(product, -1, evaluate=False)
    return non_eval_xreplace(expr, rule)
[docs]
def evalf_str(x, n, literal="", significant=True):
    """
    Evaluate ``x`` with ``x.evalf(n)`` and return it as a string, with
    ``literal`` appended verbatim.

    If ``significant`` is set, the redundant trailing zeros of the fractional
    part of the number that ends the string are removed. At least one digit
    is kept after the decimal point: "1.50000" -> "1.5", "2.00000" -> "2.0",
    "1.00000e+20" -> "1.0e+20". Digits before the decimal point and exponent
    digits are never touched. Previously one zero too many was kept ("1.50")
    and integers/exponents could lose digits ("2500" -> "250").
    """
    import re

    text = str(x.evalf(n))
    if significant:
        # group 1: the dot and the fractional digits to keep (at least one),
        # then the zeros to drop, group 2: an optional exponent suffix.
        text = re.sub(r"(\.[0-9]*?[0-9])0+((?:e[+-]?[0-9]+)?)$", r"\1\2", text)
    return text + literal
[docs]
def factor_split(
    expr,
    variables,
    constant_var=None,
    include_var=False,
    init=0,
    _factor=True,
    _handle_const=True,
):
    """
    Expand ``expr`` and split its additive terms by variable.

    Parameters
    ----------
    expr:
        sympy expression; it is expanded first.
    variables:
        Iterable of sympy symbols to split on.
    constant_var, _factor, _handle_const:
        Unused; accepted for backward compatibility only.
    include_var: bool
        If False (default) the variable is replaced by 1 in each term, so
        ``3*x`` contributes ``3``; if True the term is kept as-is.
    init:
        Starting value of every entry of the result.

    Returns
    -------
    dict
        Maps each variable to ``init`` plus the sum of the terms containing
        it. Terms containing none of the variables are ignored.

    Raises
    ------
    ValueError
        If a term contains two or more of the variables.

    NOTE(review): terms that are non-linear in their variable (x**2, sin(x))
    are not rejected; with include_var=False they become their value at
    var=1.
    """
    expr = expr.expand()
    factors = {var: init for var in variables}
    # Add.make_args yields the terms of a sum and (expr,) for anything else;
    # iterating ``expr.args`` would instead walk the factors of a single
    # product (3*x -> 3, x) or nothing at all for a bare symbol.
    for term in sm.Add.make_args(expr):
        found = term.atoms(sm.Symbol).intersection(variables)
        if not found:
            continue
        if len(found) > 1:
            raise ValueError(
                f"Expression containing two or more variables!\n{term} contains {found}."
            )
        var = found.pop()
        if not include_var:
            term = term.xreplace({var: 1})
        factors[var] += term
    return factors
[docs]
def build_eqs_from_dicts(d0, d1):
"""
Build equations from two dictionnaries (lhs and rhs).
ie. rhs-lhs = 0
For keys that are only present in one dictionary, the other operand is set to 0.
"""
treated = []
eqs = []
for k in d0.keys():
expr0 = d0[k]
expr1 = d1[k] if k in d1.keys() else 0
de = expr1 - expr0
if de != 0:
eqs.append(de)
treated.append(k)
for k in d1.keys():
if k not in treated:
expr0 = 0
expr1 = d1[k]
de = expr1 - expr0
if de != 0:
eqs.append(de)
return eqs
[docs]
def recurse_expression_tree(op, expr):
    """
    Call ``op`` on ``expr`` and then on its sub-expressions, depth-first and
    in argument order.

    Only ``sympy.Expr`` nodes are descended into: ``op`` is still called on a
    non-Expr node (e.g. a sympy ``Tuple``) but not on its arguments.
    """
    op(expr)
    children = expr.args if isinstance(expr, sm.Expr) else ()
    for child in children:
        recurse_expression_tree(op, child)
[docs]
def get_derivative_variables(expr):
    """
    Return the differentiation variables of a ``sympy.Derivative`` as a flat
    tuple, each variable repeated once per differentiation, e.g.
    ``(x0, x0, x0, x1)``.

    Both argument layouts are supported: sympy >= 1.2 stores
    ``(variable, count)`` pairs, older versions store the repeated variables
    directly.
    """
    assert isinstance(expr, sm.Derivative)
    from sympy.core import containers

    wrt = expr.args[1:]
    if not isinstance(expr.args[1], containers.Tuple):
        # sympy < 1.2: args are (f, x0, x0, x0, x1)
        return tuple(wrt)
    # sympy >= 1.2: args are (f, (x0, 3), (x1, 1))
    expanded = []
    for pair in wrt:
        for _ in range(pair[1]):
            expanded.append(pair[0])
    return tuple(expanded)
[docs]
class SetupExprI:
    """Interface (mixin) for symbolic expressions that require a setup step."""

    def setup(self, work):
        """Set the expression up; subclasses must override this method."""
        raise NotImplementedError